{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# EasyEdit Example with T5 and KN\n",
    "\n",
    "The following example shows how to use KN edit with T5.\n",
    "\n",
    "## Preparation\n",
    "\n",
    "Except for installing all the necessary dependencies, you need to download the pre-trained language model. The default hparams in the 'hparams' folder use huggingface cache in default. You can change the 'model_name' in hparam file to the model name to let huggingface automatically download the model for you. In our example, it means changing the 'model_name' to t5-3B."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```python\n",
    "alg_name: \"KN\"\n",
    "model_name: \"./hugging_cache/t5-3B\"\n",
    "device: 0\n",
    "model_parallel: false\n",
    "\n",
    "lr_scale: 1.0\n",
    "n_toks: 10\n",
    "refine: false\n",
    "batch_size: 1\n",
    "steps: 20\n",
    "adaptive_threshold: 0.2\n",
    "p: 0.4\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/mnt/8t/xkw/EasyEdit\n"
     ]
    }
   ],
   "source": [
    "%cd .."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-12-01 14:36:21,274 - easyeditor.editors.editor - INFO - Instantiating model\n",
      "12/01/2024 14:36:21 - INFO - easyeditor.editors.editor -   Instantiating model\n",
      "You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565\n",
      "/mnt/8t/xkw/anaconda3/envs/EasyEdit/lib/python3.9/site-packages/transformers/tokenization_utils_base.py:1601: FutureWarning: `clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be depracted in transformers v4.45, and will be then set to `False` by default. For more details check this issue: https://github.com/huggingface/transformers/issues/31884\n",
      "  warnings.warn(\n",
      "100%|██████████| 1/1 [00:01<00:00,  1.07s/it]\n",
      "Getting coarse neurons for each prompt...: 100%|██████████| 1/1 [08:02<00:00, 482.27s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "61 coarse neurons found - refining\n",
      "61 neurons remaining after refining\n",
      "\n",
      "Before modification - groundtruth probability: tensor([2.5646e-29], device='cuda:0')\n",
      "Argmax completion: `</s></s><extra_id_5>`\n",
      "Argmax prob: 0.21098077093542403\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1/1 [08:03<00:00, 483.76s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "After modification - groundtruth probability: tensor([4.9062e-22], device='cuda:0')\n",
      "Argmax completion: `dendendenden`\n",
      "Argmax prob: 0.9392140688144989\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "2024-12-01 14:47:10,575 - easyeditor.editors.editor - INFO - 0 editing: Who was the designer of Lahti Town Hall? -> Joe Biden  \n",
      "\n",
      " {'pre': {'rewrite_acc': [0.0], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Who was the designer of Lahti Town Hall?', 'target_new': 'Joe Biden', 'ground_truth': '<|endoftext|>', 'portability': {}, 'locality': {}, 'subject': 'Lahti Town Hall'}, 'post': {'rewrite_acc': [0.3333333432674408], 'locality': {}, 'portability': {}}}\n",
      "12/01/2024 14:47:10 - INFO - easyeditor.editors.editor -   0 editing: Who was the designer of Lahti Town Hall? -> Joe Biden  \n",
      "\n",
      " {'pre': {'rewrite_acc': [0.0], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Who was the designer of Lahti Town Hall?', 'target_new': 'Joe Biden', 'ground_truth': '<|endoftext|>', 'portability': {}, 'locality': {}, 'subject': 'Lahti Town Hall'}, 'post': {'rewrite_acc': [0.3333333432674408], 'locality': {}, 'portability': {}}}\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Metrics Summary:  {'pre': {'rewrite_acc': 0.0}, 'post': {'rewrite_acc': 0.3333333432674408}}\n",
      "[{'pre': {'rewrite_acc': [0.0], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Who was the designer of Lahti Town Hall?', 'target_new': 'Joe Biden', 'ground_truth': '<|endoftext|>', 'portability': {}, 'locality': {}, 'subject': 'Lahti Town Hall'}, 'post': {'rewrite_acc': [0.3333333432674408], 'locality': {}, 'portability': {}}}]\n",
      "<class 'transformers.models.t5.modeling_t5.T5ForConditionalGeneration'>\n"
     ]
    }
   ],
   "source": [
    "from easyeditor import KNHyperParams\n",
    "from easyeditor import BaseEditor\n",
    "import os\n",
    "\n",
    "hparams = KNHyperParams.from_hparams(\"./hparams/KN/t5-3B.yaml\")\n",
    "\n",
    "prompts = ['Who was the designer of Lahti Town Hall?']\n",
    "target_new = ['Joe Biden']\n",
    "subject = ['Lahti Town Hall']\n",
    "\n",
    "editor=BaseEditor.from_hparams(hparams)\n",
    "metrics, edited_model, _ = editor.edit(\n",
    "    prompts=prompts,\n",
    "    target_new=target_new,\n",
    "    subject=subject,\n",
    "    sequential_edit=True\n",
    ")\n",
    "print(metrics)\n",
    "print(type(edited_model))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Compare edited model and original model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import T5ForConditionalGeneration\n",
    "from transformers import T5Tokenizer\n",
    "\n",
    "device = 1\n",
    "tokenizer = T5Tokenizer.from_pretrained(\"./hugging_cache/t5-3B\")\n",
    "origin_model = T5ForConditionalGeneration.from_pretrained(\"./hugging_cache/t5-3B\").to(f\"cuda:{device}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Prompt: Who was the designer of Lahti Town Hall?\n",
      "Pre-Edit  Output: was the designer of Lahti Town Hall??? Who was the designer\n",
      "Post-Edit Output: dendendenden\n",
      "----------------------------------------------------------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "\n",
    "batch = tokenizer(prompts, return_tensors='pt', padding=True)\n",
    "\n",
    "origin_output = origin_model.generate(\n",
    "    input_ids=batch['input_ids'].to(origin_model.device),\n",
    "    attention_mask=batch['attention_mask'].to(origin_model.device),\n",
    "    max_new_tokens=20\n",
    ")\n",
    "\n",
    "output = edited_model.generate(\n",
    "    input_ids=batch['input_ids'].to(edited_model.device),\n",
    "    attention_mask=batch['attention_mask'].to(edited_model.device),\n",
    "    max_new_tokens=5\n",
    ")\n",
    "\n",
    "\n",
    "for i in range(len(prompts)):\n",
    "    print(f'Prompt: {prompts[i]}')\n",
    "    print(f'Pre-Edit  Output: {tokenizer.decode(origin_output[i], skip_special_tokens=True)}')\n",
    "    print(f'Post-Edit Output: {tokenizer.decode(output[i], skip_special_tokens=True)}')\n",
    "    print('--'*50 )\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "EasyEdit",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.20"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
